In [1]:
# Notebook setup: auto-reload edited modules on each cell run and render
# matplotlib figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
In [2]:
# fastai v1 convention uses star imports; these bring torch, DatasetType,
# Image, unet_learner, etc. into scope for the rest of the notebook.
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *
from pathlib import Path
from functools import partial
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import re
import random

## Preprocessing: crop raw images and their label masks into 224×224 tiles

In [3]:
# Raw data directory (originally
# /hpf/largeprojects/MICe/mdagys/Cnp-GFP_Study/2019-06-10_labelled/raw).
raw_dir = Path("raw")
raws = raw_dir.ls()

# Paired, alphabetically sorted lists of image files and their label files.
images = sorted(p for p in raws if "_image" in p.name)
labels = sorted(p for p in raws if "_label" in p.name)
# D-R_Z were the first slides to be labelled and are sloppier; to exclude
# them, add `and "D-R_Z" not in p.name` to both filters above.

processed_dir = Path("processed")  # destination for the cropped tiles
l = 224  # side length (px) of the square tiles fed to the network
In [ ]:
random.seed(23)  # reproducible empty-tile subsampling
empty = 0   # tiles containing no labelled pixels
popu = 0    # tiles containing at least one labelled pixel
cutoff = 1  # probability of discarding an empty tile; 1 => drop every empty tile

for image_path, label_path in zip(images, labels):
    # BUG FIX: cv.imread's second argument takes IMREAD_* flags, not color-
    # conversion codes. cv.COLOR_BGR2GRAY (== 6) was silently interpreted as
    # IMREAD_ANYDEPTH | IMREAD_ANYCOLOR; IMREAD_GRAYSCALE actually loads the
    # intended single-channel images.
    image = cv.imread(image_path.as_posix(), cv.IMREAD_GRAYSCALE)
    label = cv.imread(label_path.as_posix(), cv.IMREAD_GRAYSCALE)
    # cv.imread returns None (no exception) on unreadable files.
    if image is None or label is None:
        raise IOError("failed to read " + image_path.as_posix() +
                      " or " + label_path.as_posix())

    if image.shape != label.shape:
        raise ValueError("shape mismatch: " + image_path.as_posix() +
                         " vs " + label_path.as_posix())

    # Number of whole l x l tiles per dimension (partial edge tiles are dropped).
    i_max = image.shape[0]//l
    j_max = image.shape[1]//l

    # Binarize: cells may have been labelled 255 (or anything non-zero) instead of 1.
    label[label!=0]=1

    for i in range(i_max):
        for j in range(j_max):
            cropped_image = image[l*i:l*(i+1), l*j:l*(j+1)]
            cropped_label = label[l*i:l*(i+1), l*j:l*(j+1)]

            if (cropped_label!=0).any():
                popu+=1
                cropped_image_path = processed_dir/(image_path.stem + "_i" + str(i) + "_j" + str(j) + image_path.suffix)
                cropped_label_path = processed_dir/(label_path.stem + "_i" + str(i) + "_j" + str(j) + label_path.suffix)
            else:
                empty+=1
                # random.random() is in [0, 1), so cutoff=1 skips ALL empty tiles;
                # lower cutoff to keep a random fraction (files tagged "_empty").
                if (random.random() < cutoff):
                    continue
                cropped_image_path = processed_dir/(image_path.stem + "_i" + str(i) + "_j" + str(j) + "_empty" + image_path.suffix)
                cropped_label_path = processed_dir/(label_path.stem + "_i" + str(i) + "_j" + str(j) + "_empty" + label_path.suffix)

            cv.imwrite(cropped_image_path.as_posix(), cropped_image)
            cv.imwrite(cropped_label_path.as_posix(), cropped_label)
In [ ]:
# Tile counts from the preprocessing pass above (populated vs. empty).
print(popu)
print(empty)

## Train the U-Net segmentation network

In [4]:
# Pin training to GPU index 2 (machine-specific; adjust for your hardware).
torch.cuda.set_device(2)
In [5]:
# Segmentation classes: background vs. cell (index 0 / index 1 in the masks).
codes = ["NOT-CELL", "CELL"]
bs = 4  # batch size; see GPU memory notes below
#bs=16 and l=224 will use ~7300MiB for resnet34  before unfreezing
#bs=4 and l=224 use ~11500MiB for resnet50 before unfreezing
In [6]:
# Data augmentation: horizontal/vertical flips and up-to-45° rotations;
# zoom, lighting, and warp augmentations are disabled.
transforms = get_transforms(
    do_flip = True,
    flip_vert = True,
    max_zoom = 1, #consider enabling zoom augmentation
    max_rotate = 45,
    max_lighting = None,
    max_warp = None,
    p_affine = 0.75,
    p_lighting = 0.75)
In [7]:
def get_label_from_image(path):
    """Return the label-file path (as a posix string) for an image-file path.

    Replaces the '_image_' tag in the filename with '_label_', e.g.
    'X_image_i0_j0.png' -> 'X_label_i0_j0.png'. (PEP 8: named lambdas
    should be written as `def`.)
    """
    return re.sub(r'_image_', '_label_', path.as_posix())

# Data block: keep only non-empty image tiles, make a fixed-seed 80/20 random
# split, and derive each tile's label from the matching *_label_* file.
src = (
    SegmentationItemList.from_folder(processed_dir)
    .filter_by_func(lambda fname:
                    'image' in Path(fname).name and "empty" not in Path(fname).name)
    .split_by_rand_pct(valid_pct=0.20, seed=1)
    .label_from_func(get_label_from_image, classes=codes)
)
# Apply augmentations to image AND mask (tfm_y=True); normalize with ImageNet
# statistics because the encoder is ImageNet-pretrained.
data = (
    src.transform(transforms, tfm_y=True)
    .databunch(bs=bs)
    .normalize(imagenet_stats)
)
In [8]:
# Preview a batch of augmented (image, mask) pairs to sanity-check the pipeline.
data.show_batch(2, figsize=(10,7))
In [8]:
# models.resnet34  # smaller alternative backbone (see GPU memory notes above)
model_path = Path("../../models")
# U-Net with an ImageNet-pretrained ResNet-50 encoder; dice(iou=True) reports IoU.
learn = unet_learner(data, models.resnet50, metrics=partial(dice, iou=True))
# Equal class weights for NOT-CELL / CELL; the weight tensor must be on the GPU.
learn.loss_func = CrossEntropyFlat(axis=1, weight = torch.Tensor([1,1]).cuda())
In [9]:
# Learning-rate range test; inspect the plot to choose `lr` in the next cell.
lr_find(learn)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [10]:
# Stage 1: train with the encoder frozen, one-cycle schedule for 25 epochs.
lr = 1e-5
learn.fit_one_cycle(25, lr)
epoch train_loss valid_loss dice time
0 0.037154 0.032490 0.004730 03:28
1 0.020648 0.019681 0.179485 03:24
2 0.019034 0.017940 0.282339 03:23
3 0.017589 0.016726 0.285386 03:22
4 0.015619 0.016057 0.240093 03:23
5 0.016768 0.014846 0.312353 03:26
6 0.015572 0.014523 0.373721 03:25
7 0.016004 0.014422 0.389043 03:23
8 0.015861 0.013650 0.348052 03:23
9 0.014372 0.014264 0.258737 03:23
10 0.014837 0.015317 0.292549 03:25
11 0.014176 0.013128 0.351677 03:22
12 0.014649 0.013312 0.346994 03:22
13 0.013667 0.013173 0.332092 03:22
14 0.013455 0.013370 0.348519 03:23
15 0.012634 0.012762 0.404465 03:23
16 0.014055 0.012980 0.369382 03:23
17 0.013946 0.012866 0.415134 03:23
18 0.013503 0.012561 0.391336 03:27
19 0.013357 0.012770 0.426071 03:22
20 0.014405 0.012483 0.372560 03:25
21 0.012053 0.012482 0.399669 03:25
22 0.013814 0.012482 0.392954 03:25
23 0.013321 0.012473 0.382243 03:23
24 0.013489 0.012439 0.389492 03:23
In [22]:
# Checkpoint after stage 1; filename records the approximate IoU reached.
learn.save(model_path/"2019-07-02_RESNET50_IOU0.41_1stage")
In [ ]:
# Export this notebook to HTML for record keeping.
# NOTE(review): output name references an older RESNET50_IOU0.38 run — confirm/update.
!jupyter nbconvert gfp-cnp-train-Copy2.ipynb --to html --output nbs/2019-06-26_RESNET50_IOU0.38_2stage
In [ ]:
# Resume from the stage-1 checkpoint (trailing ';' suppresses the repr output).
learn.load(model_path/"2019-07-02_RESNET50_IOU0.41_1stage");
In [11]:
# Stage 2: unfreeze the encoder so all layers are fine-tuned.
learn.unfreeze()
In [12]:
# Re-run the LR range test after unfreezing to pick the stage-2 rates.
lr_find(learn)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [13]:
# Discriminative learning rates: earliest layers at lr/1000 up to lr/10 for the head.
lrs = slice(lr/1000,lr/10)
learn.fit_one_cycle(15, lrs)
epoch train_loss valid_loss dice time
0 0.012230 0.012469 0.385714 03:28
1 0.013415 0.012521 0.390801 03:28
2 0.013732 0.012494 0.404498 03:31
3 0.013076 0.012486 0.386176 03:30
4 0.012667 0.012626 0.414451 03:30
5 0.012152 0.012571 0.377130 03:27
6 0.012979 0.012592 0.382010 03:26
7 0.012267 0.012466 0.391705 03:27
8 0.013559 0.012479 0.399885 03:28
9 0.012752 0.012709 0.379596 03:28
10 0.012791 0.012410 0.399062 03:27
11 0.011300 0.012419 0.404317 03:27
12 0.013805 0.012487 0.386655 03:27
13 0.012888 0.012561 0.394792 03:25
14 0.014145 0.012490 0.388511 03:25
In [ ]:
# FIX: `models_path` was never defined (NameError at runtime); the learner's
# checkpoint directory is `model_path`, defined with the learner above.
# NOTE(review): filename still says RESNET34/IOU0.25 from an earlier run —
# confirm before overwriting (this run is ResNet-50).
learn.save(model_path/"2019-06-14_RESNET34_IOU0.25_2stage")
In [ ]:
# FIX: `models_path` -> `model_path` (undefined name would raise NameError).
# NOTE(review): stale RESNET34/IOU0.25 filename — consider matching the actual run.
learn.export(file = model_path/"2019-06-14_RESNET34_IOU0.25_2stage.pkl")

## Check: visualize predictions on the validation set

In [ ]:
# Inspect the validation dataset structure.
print(len(learn.data.valid_ds))   # number of validation items (len(), not .__len__())
print(learn.data.valid_ds[0])     # item 0: (input image, segmentation mask) tuple
print(learn.data.valid_ds[0][1])  # item 0's segmentation mask
In [14]:
# preds = learn.get_preds(with_loss=True)
# Run inference on the whole validation set: returns (probabilities, targets).
preds = learn.get_preds()
In [ ]:
# Anatomy of the get_preds() output for this 2-class segmentation model.
print(len(preds)) # tuple of list of probs and targets
print(preds[0].shape) #predictions
print(preds[0][0].shape) #probabilities for each label
print(learn.data.classes) #what is each label
print(preds[0][0][0].shape) #probabilities for label 0
# for i in range(0,N):
#     print(torch.max(preds[0][i][1]))

# Image(preds[1][0]).show()
In [15]:
# Sanity check: the number of predictions must match the validation set size.
if len(learn.data.valid_ds) == preds[1].shape[0]:
    N = len(learn.data.valid_ds)
else:
    raise ValueError("prediction count does not match validation set size")

xs = [learn.data.valid_ds[i][0] for i in range(N)]  # input images
ys = [learn.data.valid_ds[i][1] for i in range(N)]  # ground-truth masks
p0s = [Image(preds[0][i][0]) for i in range(N)]     # class-0 (NOT-CELL) probability maps
p1s = [Image(preds[0][i][1]) for i in range(N)]     # class-1 (CELL) probability maps
argmax = [Image(preds[0][i].argmax(dim=0)) for i in range(N)]  # hard predictions
In [ ]:
# Inspect tensor shapes of inputs, masks, and the per-class probability maps.
print(xs[0].px.shape)
print(ys[0].px.shape)
print(p0s[0].px.shape)
print(p1s[0].px.shape)
In [17]:
ncol = 3
# Ceiling division: exactly enough rows to hold all N panels.
nrow = (N + ncol - 1) // ncol
fig = plt.figure(figsize=(12, nrow*5))

# BUG FIX: the original `for i in range(1, N)` drew only N-1 panels and
# skipped the last validation item; iterate over all N items instead
# (subplot indices are 1-based, hence i + 1).
for i in range(N):
    fig.add_subplot(nrow, ncol, i + 1)
    # Predicted mask (blue) overlaid on the ground-truth mask (orange).
    # Alternative overlays: the raw input `xs[i].px.permute(1, 2, 0)` or the
    # soft probability map `p1s[i].px`.
    plt.imshow(argmax[i].px, cmap = "Blues", alpha=0.5)
    plt.imshow(ys[i].px[0], cmap = "Oranges", alpha=0.5)
plt.show()
In [18]:
# Predicted masks (blue) overlaid on the grayscale input images.
fig = plt.figure(figsize=(12, nrow*5))

# BUG FIX: `for i in range(1, N)` skipped the last validation item;
# iterate over all N items (subplot indices are 1-based).
for i in range(N):
    fig.add_subplot(nrow, ncol, i + 1)
    plt.imshow(xs[i].px.permute(1, 2, 0), cmap = "Greys", alpha=1)
    plt.imshow(argmax[i].px, cmap = "Blues", alpha=0.5)
plt.show()
In [17]:
# Visual spot check on TRAINING data (compare against validation results below).
learn.show_results(rows=16, ds_type=DatasetType.Train)
In [20]:
# Visual spot check on validation data (the default ds_type).
learn.show_results(rows=16)